Conversation

!test

!test
Greptile Summary

This PR implements the auto TMA transpose scheduler. Key issues found:
Confidence Score: 3/5
Neither issue is triggered by the test suite as written, but both represent real failure modes for edge-case inputs or manually-constructed params. The code quality issues (test cout and typos) are minor.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[TransposeScheduler::computeHeuristics] --> B{TmaTranspose enabled?}
    B -- Yes --> C[tma::getTransposeHeuristics]
    B -- No --> E[non_tma::getTransposeHeuristics]
    C --> D{"n_input > n_output?"}
    D -- "Yes\nis_output_smem_transpose=true" --> F["use_tma_load=true\nuse_tma_store=true\nswizzle on output smem"]
    D -- "No\nis_output_smem_transpose=false" --> G["use_tma_load=true\nuse_tma_store=false\nswizzle on input smem"]
    F --> H[Return TransposeParams]
    G --> H
    C -- returns null --> E
    E --> H
    H --> I[TransposeScheduler::schedule]
    I --> J{"use_tma_load OR use_tma_store?"}
    J -- Yes --> K[tma::scheduleTranspose]
    J -- No --> L[non_tma::scheduleTranspose]
    K --> M{is_output_smem_transpose?}
    M -- true --> N["TMA load w/o swizzle\nTMA store w/ MmaSwizzle\ntranspose at regs→output smem"]
    M -- false --> O["TMA load w/ MmaSwizzle\nregister store\ntranspose at input smem→regs"]
```

Last reviewed commit: bc772db
csrc/scheduler/transpose_tma.cpp
Outdated
```cpp
NVF_ERROR(grouped_inputs_outputs.size() >= 2);

// When there are more inputs than outputs, output smem transpose should be
// used, however, if it is not, then input smem tranpose will be used, to
```
`tranpose` should be `transpose`
```cpp
const int64_t cta_per_sm =
    dev_props->maxThreadsPerMultiProcessor / threads_per_cta;
const int64_t bytes_per_cta = bytes_per_sm / cta_per_sm;
const int64_t bytes_per_tile = bytes_per_cta / n_input;
```
Add check that n_input > 0 before this division. While the scheduler validation should prevent this, defensive programming would make the code more robust.
```diff
- const int64_t bytes_per_tile = bytes_per_cta / n_input;
+ NVF_ERROR(n_input > 0, "Expected at least one TensorView input for transpose");
+ const int64_t bytes_per_tile = bytes_per_cta / n_input;
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Review updated until commit bc772db

Description
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- ⚡ Recommended focus areas for review
Potential TMA load restriction
This is more restrictive than the original, which checked all loop domains. It could exclude valid TMA loads where some dimensions have extent 1 but other dimensions are parallelized with threads. Need to verify this doesn't break existing TMA use cases.
Additional Comments (2)

On an H100 (maxThreadsPerMultiProcessor = 2048, cta_per_sm = 8, bytes_per_cta = 8192), this triggers when
Additional Comments (4)

While unlikely for typical transpose fusions (1–2 inputs), this is an unbounded loop with no guard. A simple fix is to initialise

Note the asymmetry: Step 3 already guards the analogous constraint with an explicit
To reduce the number of transpose ops, `is_output_smem_transpose` is added to control input/output transpose:

1. When there are more inputs than outputs, `is_output_smem_transpose = True`: TMA load without swizzle, TMA store with swizzle, transpose at regs --> output cached smem.
2. When there are fewer inputs than outputs, `is_output_smem_transpose = False`: TMA load with swizzle, register store, transpose at input cached smem -> regs.

Current performance is in this doc.